{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2023-06-09T12:23:40.659026433Z",
     "start_time": "2023-06-09T12:23:33.665930760Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-06-09 14:23:33.930923: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n",
      "2023-06-09 14:23:34.057324: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n",
      "2023-06-09 14:23:34.057796: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2023-06-09 14:23:34.762430: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
      "2023-06-09 14:23:35.992621: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
      "2023-06-09 14:23:35.993068: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1956] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.\n",
      "Skipping registering GPU devices...\n",
      "2023-06-09 14:23:36.425651: W tensorflow/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 3488249856 exceeds 10% of free system memory.\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "from tensorflow.keras.utils import to_categorical\n",
    "\n",
    "## Loading images and labels\n",
    "(train_ds, train_labels), (test_ds, test_labels) = tfds.load(\n",
    "    \"tf_flowers\",\n",
    "    split=[\"train[:70%]\", \"train[:30%]\"], ## Train test split\n",
    "    batch_size=-1,\n",
    "    as_supervised=True,  # Include labels\n",
    ")\n",
    "\n",
    "## Resizing images\n",
    "train_ds = tf.image.resize(train_ds, (150, 150))\n",
    "test_ds = tf.image.resize(test_ds, (150, 150))\n",
    "\n",
    "## Transforming labels to correct format\n",
    "train_labels = to_categorical(train_labels, num_classes=5)\n",
    "test_labels = to_categorical(test_labels, num_classes=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "from tensorflow.keras.applications.vgg16 import VGG16\n",
    "from tensorflow.keras.applications.vgg16 import preprocess_input\n",
    "\n",
    "## Loading VGG16 model\n",
    "base_model = VGG16(weights=\"imagenet\", include_top=False, input_shape=train_ds[0].shape)\n",
    "base_model.trainable = False ## Not trainable weights\n",
    "\n",
    "## Preprocessing input\n",
    "train_ds = preprocess_input(train_ds)\n",
    "test_ds = preprocess_input(test_ds)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-09T12:23:41.424220467Z",
     "start_time": "2023-06-09T12:23:40.659745875Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"vgg16\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " input_1 (InputLayer)        [(None, 150, 150, 3)]     0         \n",
      "                                                                 \n",
      " block1_conv1 (Conv2D)       (None, 150, 150, 64)      1792      \n",
      "                                                                 \n",
      " block1_conv2 (Conv2D)       (None, 150, 150, 64)      36928     \n",
      "                                                                 \n",
      " block1_pool (MaxPooling2D)  (None, 75, 75, 64)        0         \n",
      "                                                                 \n",
      " block2_conv1 (Conv2D)       (None, 75, 75, 128)       73856     \n",
      "                                                                 \n",
      " block2_conv2 (Conv2D)       (None, 75, 75, 128)       147584    \n",
      "                                                                 \n",
      " block2_pool (MaxPooling2D)  (None, 37, 37, 128)       0         \n",
      "                                                                 \n",
      " block3_conv1 (Conv2D)       (None, 37, 37, 256)       295168    \n",
      "                                                                 \n",
      " block3_conv2 (Conv2D)       (None, 37, 37, 256)       590080    \n",
      "                                                                 \n",
      " block3_conv3 (Conv2D)       (None, 37, 37, 256)       590080    \n",
      "                                                                 \n",
      " block3_pool (MaxPooling2D)  (None, 18, 18, 256)       0         \n",
      "                                                                 \n",
      " block4_conv1 (Conv2D)       (None, 18, 18, 512)       1180160   \n",
      "                                                                 \n",
      " block4_conv2 (Conv2D)       (None, 18, 18, 512)       2359808   \n",
      "                                                                 \n",
      " block4_conv3 (Conv2D)       (None, 18, 18, 512)       2359808   \n",
      "                                                                 \n",
      " block4_pool (MaxPooling2D)  (None, 9, 9, 512)         0         \n",
      "                                                                 \n",
      " block5_conv1 (Conv2D)       (None, 9, 9, 512)         2359808   \n",
      "                                                                 \n",
      " block5_conv2 (Conv2D)       (None, 9, 9, 512)         2359808   \n",
      "                                                                 \n",
      " block5_conv3 (Conv2D)       (None, 9, 9, 512)         2359808   \n",
      "                                                                 \n",
      " block5_pool (MaxPooling2D)  (None, 4, 4, 512)         0         \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 14,714,688\n",
      "Trainable params: 0\n",
      "Non-trainable params: 14,714,688\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "base_model.summary()"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-09T12:23:41.483171630Z",
     "start_time": "2023-06-09T12:23:41.424730309Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/CP229110.edf...\n",
      "EDF file detected\n",
      "Setting channel info structure...\n",
      "Creating raw.info structure...\n",
      "NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).\n",
      "Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/WD224010.edf...\n",
      "EDF file detected\n",
      "Setting channel info structure...\n",
      "Creating raw.info structure...\n",
      "NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).\n",
      "Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/TK221110.edf...\n",
      "EDF file detected\n",
      "Setting channel info structure...\n",
      "Creating raw.info structure...\n",
      "NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).\n",
      "Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/VP214110.edf...\n",
      "EDF file detected\n",
      "Setting channel info structure...\n",
      "Creating raw.info structure...\n",
      "NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).\n",
      "Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/CN223100.edf...\n",
      "EDF file detected\n",
      "Setting channel info structure...\n",
      "Creating raw.info structure...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/MyEDFImports.py:20: RuntimeWarning: Channel names are not unique, found duplicates for: {'CHIN EMG'}. Applying running numbers for duplicates.\n",
      "  raw = mne.io.read_raw_edf(path + \"//\" + name)\n",
      "/home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/MyEDFImports.py:20: RuntimeWarning: Channel names are not unique, found duplicates for: {'CHIN EMG'}. Applying running numbers for duplicates.\n",
      "  raw = mne.io.read_raw_edf(path + \"//\" + name)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).\n",
      "Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/LM230010.edf...\n",
      "EDF file detected\n",
      "Setting channel info structure...\n",
      "Creating raw.info structure...\n",
      "NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).\n",
      "Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/VC209100.edf...\n",
      "EDF file detected\n",
      "Setting channel info structure...\n",
      "Creating raw.info structure...\n",
      "NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).\n",
      "Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/LA216100.edf...\n",
      "EDF file detected\n",
      "Setting channel info structure...\n",
      "Creating raw.info structure...\n",
      "NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).\n",
      "Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/DG220020.edf...\n",
      "EDF file detected\n",
      "Setting channel info structure...\n",
      "Creating raw.info structure...\n",
      "NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).\n",
      "Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/DO223050.edf...\n",
      "EDF file detected\n",
      "Setting channel info structure...\n",
      "Creating raw.info structure...\n",
      "NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).\n",
      "Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/CX230050.edf...\n",
      "EDF file detected\n",
      "Setting channel info structure...\n",
      "Creating raw.info structure...\n",
      "NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).\n",
      "loading file with nr of windows 2007\n",
      "loading file with nr of windows 1777\n",
      "loading file with nr of windows 1599\n",
      "loading file with nr of windows 1725\n",
      "loading file with nr of windows 1561\n",
      "loading file with nr of windows 1724\n",
      "loading file with nr of windows 1843\n",
      "loading file with nr of windows 1633\n",
      "loading file with nr of windows 1775\n",
      "loading file with nr of windows 1806\n",
      "loading file with nr of windows 1798\n"
     ]
    }
   ],
   "source": [
    "from MyEDFImports import load_all_labels, load_all_data\n",
    "data = load_all_data()\n",
    "labels = load_all_labels()\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-09T12:23:50.096572688Z",
     "start_time": "2023-06-09T12:23:41.483950583Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "output = np.ndarray(shape=(len(data),127,127))\n",
    "from MyEDFImports import sizeof\n",
    "sizeof(output)\n",
    "# size is about 2,3 GB"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [],
   "source": [
    "import pywt\n",
    "def calculate_dwf(signals, wavelet_name = 'morl', scales = range(1, 128)):\n",
    "    size = len(signals)\n",
    "    output = np.ndarray(shape=(size,len(scales), len(scales)))\n",
    "    for ind in range(size):\n",
    "        if ind % 500 == 0:\n",
    "            print(ind)\n",
    "        signal = signals[ind]\n",
    "        coef, freq = pywt.cwt(signal, scales, wavelet_name)\n",
    "        coef_ = coef[:, :len(scales)]\n",
    "        output[ind] = coef_\n",
    "    return output\n",
    "\n",
    "def three_stages_transform(n: int):\n",
    "    if n == 0:\n",
    "        return 0\n",
    "    if n == 5:\n",
    "        return 2\n",
    "    return 1\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-09T12:23:50.130273461Z",
     "start_time": "2023-06-09T12:23:50.098294893Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "500\n",
      "1000\n",
      "1500\n",
      "2000\n",
      "2500\n",
      "3000\n",
      "3500\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "data_short_4000 = data[:4000]\n",
    "cwt_short_4000 = calculate_dwf(data_short_4000)\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-09T12:42:42.418246999Z",
     "start_time": "2023-06-09T12:25:53.705935573Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "outputs": [],
   "source": [
    "from MyEDFImports import sizeof\n",
    "sizeof(cwt_short_4000)\n",
    "np.save(f'cwt_short_{len(cwt_short_4000)}', cwt_short_4000)\n",
    "# 4k is 0.5 GB and it calculates in 15 min"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-09T12:56:47.805818820Z",
     "start_time": "2023-06-09T12:56:47.504642741Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def three_stages_transform(n: int):\n",
    "    if n == 0:\n",
    "        return 0\n",
    "    if n == 5:\n",
    "        return 2\n",
    "    return 1"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "start_time": "2023-06-09T12:23:57.950601547Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "coef,freq = pywt.cwt(data[6000], wavelet='morl', scales=range(1,128))\n",
    "coef[:]"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "plt.imshow(coef, extent=[-1, 1, 1, int(2)], cmap='PRGn', aspect='auto',\n",
    "               vmax=abs(coef).max(), vmin=-abs(coef).max())"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "coef1 = coef[:,:500]\n",
    "coef1\n",
    "import matplotlib.pyplot as plt\n",
    "plt.imshow(coef1, extent=[0, 20, 1, 127], cmap='PRGn', aspect='auto',\n",
    "               vmax=abs(coef1).max(), vmin=-abs(coef1).max())"
   ],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
